{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CORAL MLP for ordinal regression and deep learning -- cement strength dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial explains how to train a deep neural network (here: multilayer perceptron) with the CORAL layer and loss function for ordinal regression. \n",
    "\n",
    "**CORAL reference:**\n",
    "\n",
    "- Wenzhi Cao, Vahid Mirjalili, and Sebastian Raschka (2020) \n",
    "[Rank Consistent Ordinal Regression for Neural Networks with Application to Age Estimation](https://www.sciencedirect.com/science/article/pii/S016786552030413X) \n",
    "Pattern Recognition Letters. 140, 325-331"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0 -- Obtaining and preparing the cement_strength dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will be using the cement_strength dataset from [https://github.com/gagolews/ordinal_regression_data/blob/master/cement_strength.csv](https://github.com/gagolews/ordinal_regression_data/blob/master/cement_strength.csv).\n",
    "\n",
    "First, we are going to download the dataset by reading the CSV file directly into a pandas DataFrame and then prepare the labels. This is a general procedure that is not specific to CORAL.\n",
    "\n",
    "This dataset has 5 ordinal labels (1, 2, 3, 4, and 5). Note that we require the labels to start at 0, which is why we subtract 1 from the label column."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of features: 8\n",
      "Number of examples: 998\n",
      "Labels: [0 1 2 3 4]\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "data_df = pd.read_csv(\"https://raw.githubusercontent.com/gagolews/ordinal_regression_data/master/cement_strength.csv\")\n",
    "\n",
    "data_df[\"response\"] = data_df[\"response\"]-1 # labels should start at 0\n",
    "\n",
    "data_labels = data_df[\"response\"]\n",
    "data_features = data_df.loc[:, [\"V1\", \"V2\", \"V3\", \"V4\", \"V5\", \"V6\", \"V7\", \"V8\"]]\n",
    "\n",
    "print('Number of features:', data_features.shape[1])\n",
    "print('Number of examples:', data_features.shape[0])\n",
    "print('Labels:', np.unique(data_labels.values))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Split into training and test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    data_features.values,\n",
    "    data_labels.values,\n",
    "    test_size=0.2,\n",
    "    random_state=1,\n",
    "    stratify=data_labels.values)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Standardize features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "\n",
    "sc = StandardScaler()\n",
    "X_train_std = sc.fit_transform(X_train)\n",
    "X_test_std = sc.transform(X_test)"
   ]
  },
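  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a brief sketch of what the cell above computes, assuming the default `StandardScaler` settings (the symbols $\\mu_j$ and $\\sigma_j$ are notation introduced here for illustration): each feature column $j$ is centered and scaled using statistics estimated from the training set only, and the same statistics are then reused for the test set.\n",
    "\n",
    "$$z_{ij} = \\frac{x_{ij} - \\mu_j}{\\sigma_j}, \\qquad \\mu_j = \\frac{1}{n}\\sum_{i=1}^{n} x_{ij}, \\qquad \\sigma_j = \\sqrt{\\frac{1}{n}\\sum_{i=1}^{n} (x_{ij} - \\mu_j)^2}$$\n",
    "\n",
    "Here, $n$ is the number of training examples. Fitting the scaler on the training split and only calling `transform` on the test split avoids leaking test-set information into the preprocessing step."
   ]
  },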
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 -- Setting up the dataset and dataloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, we set up the dataset and data loaders. This is a general procedure that is not specific to CORAL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training on cuda:0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Hyperparameters\n",
    "random_seed = 1\n",
    "learning_rate = 0.05\n",
    "num_epochs = 50\n",
    "batch_size = 128\n",
    "\n",
    "# Architecture\n",
    "NUM_CLASSES = 5  # the dataset has 5 ordinal labels (0-4 after shifting)\n",
    "\n",
    "# Other\n",
    "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print('Training on', DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "\n",
    "\n",
    "class MyDataset(Dataset):\n",
    "\n",
    "    def __init__(self, feature_array, label_array, dtype=np.float32):\n",
    "    \n",
    "        self.features = feature_array.astype(dtype)  # cast to the requested dtype\n",
    "        self.labels = label_array\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        inputs = self.features[index]\n",
    "        label = self.labels[index]\n",
    "        return inputs, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.labels.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input batch dimensions: torch.Size([128, 8])\n",
      "Input label dimensions: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "# Wrap the standardized NumPy feature arrays and the\n",
    "# integer labels in Dataset objects\n",
    "train_dataset = MyDataset(X_train_std, y_train)\n",
    "test_dataset = MyDataset(X_test_std, y_test)\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset,\n",
    "                          batch_size=batch_size,\n",
    "                          shuffle=True, # want to shuffle the dataset\n",
    "                          num_workers=0) # number of worker processes (0 = load in the main process)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset,\n",
    "                         batch_size=batch_size,\n",
    "                         shuffle=False,\n",
    "                         num_workers=0)\n",
    "\n",
    "# Checking the dataset\n",
    "for inputs, labels in train_loader:  \n",
    "    print('Input batch dimensions:', inputs.shape)\n",
    "    print('Input label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 -- Equipping MLP with CORAL layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, we are using the CoralLayer implemented in `coral_pytorch` to outfit a multilayer perceptron for ordinal regression. Note that the CORAL method only requires replacing the last (output) layer, which is typically a fully-connected layer, with the CORAL layer.\n",
    "\n",
    "Also, please apply the `sigmoid` function, not softmax, to the outputs of the CORAL layer, since the CORAL method is based on a concept known as extended binary classification, as described in the paper."
   ]
  },
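  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what the CORAL layer computes (the notation is introduced here for illustration and is not part of the notebook): let $\\mathbf{h}$ be the output of the last hidden layer, $\\mathbf{w}$ the single weight vector shared by all tasks, and $b_1, \\dots, b_{K-1}$ the biases for $K$ classes. The layer then produces $K-1$ logits:\n",
    "\n",
    "$$z_k = \\mathbf{w}^\\top \\mathbf{h} + b_k, \\qquad \\hat{P}(y > r_k) = \\sigma(z_k) = \\frac{1}{1 + e^{-z_k}}, \\qquad k = 1, \\dots, K-1$$\n",
    "\n",
    "Each $\\sigma(z_k)$ is the estimated probability that the label exceeds rank $r_k$. Since these are $K-1$ separate binary decisions rather than one distribution over classes, the sigmoid is applied to each logit individually."
   ]
  },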
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pip install coral-pytorch"
   ]
  },
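  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CoralLayer can also be imported directly via\n",
    "# `from coral_pytorch.layers import CoralLayer`.\n",
    "# It is defined explicitly below for illustration, so the import\n",
    "# is omitted here to avoid shadowing it with the local class.\n",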
    "\n",
    "\n",
    "class CoralLayer(torch.nn.Module):\n",
    "    \"\"\" Implements CORAL layer described in\n",
    "\n",
    "    Cao, Mirjalili, and Raschka (2020)\n",
    "    *Rank Consistent Ordinal Regression for Neural Networks\n",
    "       with Application to Age Estimation*\n",
    "    Pattern Recognition Letters, https://doi.org/10.1016/j.patrec.2020.11.008\n",
    "\n",
    "    Parameters\n",
    "    -----------\n",
    "    size_in : int\n",
    "        Number of input features for the inputs to the forward method, which\n",
    "        are expected to have shape=(num_examples, num_features).\n",
    "\n",
    "    num_classes : int\n",
    "        Number of classes in the dataset.\n",
    "\n",
    "    preinit_bias : bool (default=True)\n",
    "        If true, it will pre-initialize the biases to descending values in\n",
    "        [0, 1] range instead of initializing it to all zeros. This pre-\n",
    "        initialization scheme results in faster learning and better\n",
    "        generalization performance in practice.\n",
    "\n",
    "\n",
    "    \"\"\"\n",
    "    def __init__(self, size_in, num_classes, preinit_bias=True):\n",
    "        super().__init__()\n",
    "        self.size_in, self.size_out = size_in, 1\n",
    "\n",
    "        self.coral_weights = torch.nn.Linear(self.size_in, 1, bias=False)\n",
    "        if preinit_bias:\n",
    "            self.coral_bias = torch.nn.Parameter(\n",
    "                torch.arange(num_classes - 1, 0, -1).float() / (num_classes-1))\n",
    "        else:\n",
    "            self.coral_bias = torch.nn.Parameter(\n",
    "                torch.zeros(num_classes-1).float())\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Computes forward pass.\n",
    "\n",
    "        Parameters\n",
    "        -----------\n",
    "        x : torch.tensor, shape=(num_examples, num_features)\n",
    "            Input features.\n",
    "\n",
    "        Returns\n",
    "        -----------\n",
    "        logits : torch.tensor, shape=(num_examples, num_classes-1)\n",
    "        \"\"\"\n",
    "        return self.coral_weights(x) + self.coral_bias\n",
    "\n",
    "\n",
    "\n",
    "class MLP(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, in_features, num_classes, num_hidden_1=300, num_hidden_2=300):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.my_network = torch.nn.Sequential(\n",
    "            \n",
    "            # 1st hidden layer\n",
    "            torch.nn.Linear(in_features, num_hidden_1, bias=False),\n",
    "            torch.nn.LeakyReLU(),\n",
    "            torch.nn.Dropout(0.2),\n",
    "            torch.nn.BatchNorm1d(num_hidden_1),\n",
    "            \n",
    "            # 2nd hidden layer\n",
    "            torch.nn.Linear(num_hidden_1, num_hidden_2, bias=False),\n",
    "            torch.nn.LeakyReLU(),\n",
    "            torch.nn.Dropout(0.2),\n",
    "            torch.nn.BatchNorm1d(num_hidden_2),\n",
    "        )\n",
    "        \n",
    "        ### Specify CORAL layer\n",
    "        self.fc = CoralLayer(size_in=num_hidden_2, num_classes=num_classes)\n",
    "        ###--------------------------------------------------------------------###\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.my_network(x)\n",
    "        \n",
    "        ##### Use CORAL layer #####\n",
    "        logits =  self.fc(x)\n",
    "        ###--------------------------------------------------------------------###\n",
    "        \n",
    "        return logits\n",
    "    \n",
    "    \n",
    "    \n",
    "torch.manual_seed(random_seed)\n",
    "model = MLP(in_features=8, num_classes=NUM_CLASSES)\n",
    "model.to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 -- Using the CORAL loss for model training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "During training, all you need to do is the following:\n",
    "\n",
    "1) Convert the integer class labels into the extended binary label format using the `levels_from_labelbatch` function provided via `coral_pytorch`:\n",
    "\n",
    "```python\n",
    "        levels = levels_from_labelbatch(class_labels, \n",
    "                                        num_classes=NUM_CLASSES)\n",
    "```\n",
    "\n",
    "2) Apply the CORAL loss (also provided via `coral_pytorch`):\n",
    "\n",
    "```python\n",
    "        loss = coral_loss(logits, levels)\n",
    "```\n"
   ]
  },
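  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the two steps can be written out as follows. This is a sketch in notation introduced here ($K$ classes, integer labels $y_i$ ranging from $0$ to $K-1$, logits $z_{ik}$, batch size $N$, and $\\mathbf{1}[\\cdot]$ as the indicator function); averaging over the batch assumes that `coral_loss` is used with a mean reduction.\n",
    "\n",
    "$$y_i^{(k)} = \\mathbf{1}[y_i \\geq k], \\qquad k = 1, \\dots, K-1$$\n",
    "\n",
    "For example, with $K = 5$, the label $y_i = 2$ becomes $(1, 1, 0, 0)$ and the label $y_i = 0$ becomes $(0, 0, 0, 0)$. The loss is then the binary cross-entropy summed over the $K-1$ tasks:\n",
    "\n",
    "$$L = -\\frac{1}{N}\\sum_{i=1}^{N}\\sum_{k=1}^{K-1}\\left[ y_i^{(k)} \\log \\sigma(z_{ik}) + \\left(1 - y_i^{(k)}\\right) \\log\\left(1 - \\sigma(z_{ik})\\right) \\right]$$"
   ]
  },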
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/050 | Batch 000/007 | Loss: 7.8762\n",
      "Epoch: 002/050 | Batch 000/007 | Loss: 5.2725\n",
      "Epoch: 003/050 | Batch 000/007 | Loss: 3.6569\n",
      "Epoch: 004/050 | Batch 000/007 | Loss: 2.3914\n",
      "Epoch: 005/050 | Batch 000/007 | Loss: 1.9764\n",
      "Epoch: 006/050 | Batch 000/007 | Loss: 1.8865\n",
      "Epoch: 007/050 | Batch 000/007 | Loss: 1.7322\n",
      "Epoch: 008/050 | Batch 000/007 | Loss: 1.5864\n",
      "Epoch: 009/050 | Batch 000/007 | Loss: 1.5017\n",
      "Epoch: 010/050 | Batch 000/007 | Loss: 1.3769\n",
      "Epoch: 011/050 | Batch 000/007 | Loss: 1.3414\n",
      "Epoch: 012/050 | Batch 000/007 | Loss: 1.3958\n",
      "Epoch: 013/050 | Batch 000/007 | Loss: 1.2516\n",
      "Epoch: 014/050 | Batch 000/007 | Loss: 1.1108\n",
      "Epoch: 015/050 | Batch 000/007 | Loss: 1.1921\n",
      "Epoch: 016/050 | Batch 000/007 | Loss: 1.1349\n",
      "Epoch: 017/050 | Batch 000/007 | Loss: 1.1515\n",
      "Epoch: 018/050 | Batch 000/007 | Loss: 1.1321\n",
      "Epoch: 019/050 | Batch 000/007 | Loss: 1.2286\n",
      "Epoch: 020/050 | Batch 000/007 | Loss: 1.0441\n",
      "Epoch: 021/050 | Batch 000/007 | Loss: 1.1615\n",
      "Epoch: 022/050 | Batch 000/007 | Loss: 0.9256\n",
      "Epoch: 023/050 | Batch 000/007 | Loss: 0.9923\n",
      "Epoch: 024/050 | Batch 000/007 | Loss: 0.9968\n",
      "Epoch: 025/050 | Batch 000/007 | Loss: 0.9609\n",
      "Epoch: 026/050 | Batch 000/007 | Loss: 0.9679\n",
      "Epoch: 027/050 | Batch 000/007 | Loss: 0.9891\n",
      "Epoch: 028/050 | Batch 000/007 | Loss: 1.0285\n",
      "Epoch: 029/050 | Batch 000/007 | Loss: 0.8547\n",
      "Epoch: 030/050 | Batch 000/007 | Loss: 0.8384\n",
      "Epoch: 031/050 | Batch 000/007 | Loss: 0.9744\n",
      "Epoch: 032/050 | Batch 000/007 | Loss: 0.8030\n",
      "Epoch: 033/050 | Batch 000/007 | Loss: 1.0487\n",
      "Epoch: 034/050 | Batch 000/007 | Loss: 0.8094\n",
      "Epoch: 035/050 | Batch 000/007 | Loss: 0.9045\n",
      "Epoch: 036/050 | Batch 000/007 | Loss: 0.9080\n",
      "Epoch: 037/050 | Batch 000/007 | Loss: 0.8050\n",
      "Epoch: 038/050 | Batch 000/007 | Loss: 0.8234\n",
      "Epoch: 039/050 | Batch 000/007 | Loss: 0.8839\n",
      "Epoch: 040/050 | Batch 000/007 | Loss: 0.7351\n",
      "Epoch: 041/050 | Batch 000/007 | Loss: 0.8924\n",
      "Epoch: 042/050 | Batch 000/007 | Loss: 0.8626\n",
      "Epoch: 043/050 | Batch 000/007 | Loss: 0.7739\n",
      "Epoch: 044/050 | Batch 000/007 | Loss: 0.7156\n",
      "Epoch: 045/050 | Batch 000/007 | Loss: 0.6445\n",
      "Epoch: 046/050 | Batch 000/007 | Loss: 1.0300\n",
      "Epoch: 047/050 | Batch 000/007 | Loss: 0.7190\n",
      "Epoch: 048/050 | Batch 000/007 | Loss: 0.7085\n",
      "Epoch: 049/050 | Batch 000/007 | Loss: 0.7748\n",
      "Epoch: 050/050 | Batch 000/007 | Loss: 0.6272\n"
     ]
    }
   ],
   "source": [
    "from coral_pytorch.dataset import levels_from_labelbatch\n",
    "from coral_pytorch.losses import coral_loss\n",
    "\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    \n",
    "    model = model.train()\n",
    "    for batch_idx, (features, class_labels) in enumerate(train_loader):\n",
    "\n",
    "        ##### Convert class labels for CORAL\n",
    "        levels = levels_from_labelbatch(class_labels, \n",
    "                                        num_classes=NUM_CLASSES)\n",
    "        ###--------------------------------------------------------------------###\n",
    "\n",
    "        features = features.to(DEVICE)\n",
    "        levels = levels.to(DEVICE)\n",
    "        logits = model(features)\n",
    "        \n",
    "        #### CORAL loss \n",
    "        loss = coral_loss(logits, levels)\n",
    "        ###--------------------------------------------------------------------###   \n",
    "        \n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 200:\n",
    "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Loss: %.4f' \n",
    "                   %(epoch+1, num_epochs, batch_idx, \n",
    "                     len(train_loader), loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 4 -- Evaluate model\n",
    "\n",
    "Finally, after model training, we can evaluate the performance of the model, for example, via the mean absolute error (MAE) and mean squared error (MSE).\n",
    "\n",
    "For this, we are going to use the `proba_to_label` utility function from `coral_pytorch` to convert the predicted probabilities back to integer class labels.\n"
   ]
  },
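  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of the evaluation logic (the notation is introduced here; the 0.5 threshold reflects how `proba_to_label` is understood to work and is worth checking against the installed `coral_pytorch` version), the predicted label counts how many of the $K-1$ binary tasks are predicted as positive:\n",
    "\n",
    "$$\\hat{y}_i = \\sum_{k=1}^{K-1} \\mathbf{1}[\\sigma(z_{ik}) > 0.5]$$\n",
    "\n",
    "For example, the probabilities $(0.9, 0.7, 0.2, 0.1)$ yield the predicted label $2$. The two error measures over $N$ examples are then:\n",
    "\n",
    "$$\\mathrm{MAE} = \\frac{1}{N}\\sum_{i=1}^{N} \\left| y_i - \\hat{y}_i \\right|, \\qquad \\mathrm{MSE} = \\frac{1}{N}\\sum_{i=1}^{N} \\left( y_i - \\hat{y}_i \\right)^2$$"
   ]
  },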
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from coral_pytorch.dataset import proba_to_label\n",
    "\n",
    "\n",
    "def compute_mae_and_mse(model, data_loader, device):\n",
    "\n",
    "    model.eval()  # disable dropout and use the BatchNorm running statistics\n",
    "    with torch.no_grad():\n",
    "    \n",
    "        mae, mse, num_examples = 0., 0., 0\n",
    "\n",
    "        for features, targets in data_loader:\n",
    "\n",
    "            features = features.to(device)\n",
    "            targets = targets.float().to(device)\n",
    "\n",
    "            logits = model(features)\n",
    "            probas = torch.sigmoid(logits)\n",
    "            predicted_labels = proba_to_label(probas).float()\n",
    "\n",
    "            num_examples += targets.size(0)\n",
    "            mae += torch.sum(torch.abs(predicted_labels - targets))\n",
    "            mse += torch.sum((predicted_labels - targets)**2)\n",
    "\n",
    "        mae = mae / num_examples\n",
    "        mse = mse / num_examples\n",
    "        return mae, mse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_mae, train_mse = compute_mae_and_mse(model, train_loader, DEVICE)\n",
    "test_mae, test_mse = compute_mae_and_mse(model, test_loader, DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean absolute error (train/test): 0.28 | 0.36\n",
      "Mean squared error (train/test): 0.29 | 0.38\n"
     ]
    }
   ],
   "source": [
    "print(f'Mean absolute error (train/test): {train_mae:.2f} | {test_mae:.2f}')\n",
    "print(f'Mean squared error (train/test): {train_mse:.2f} | {test_mse:.2f}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
